# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import os
import sys
import typing
import uuid
from collections.abc import AsyncGenerator
from collections.abc import Callable
from collections.abc import Sequence
from pathlib import Path
from unittest import mock

import pytest
import pytest_asyncio
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun
from langchain_core.callbacks import AsyncCallbackManagerForToolRun
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.outputs import ChatResult
from langchain_core.tools import BaseTool
from pydantic import BaseModel

TESTS_DIR = os.path.dirname(__file__)
PROJECT_DIR = os.path.dirname(TESTS_DIR)
SRC_DIR = os.path.join(PROJECT_DIR, "src")
EXAMPLES_DIR = os.path.join(PROJECT_DIR, "examples")
sys.path.append(SRC_DIR)

if typing.TYPE_CHECKING:
    from dask.distributed import LocalCluster
    from sqlalchemy.ext.asyncio import AsyncEngine

    from nat.data_models.intermediate_step import IntermediateStep
    from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor


@pytest.fixture(name="project_dir", scope='session')
def project_dir_fixture(root_repo_dir: Path) -> str:
    return str(root_repo_dir)


@pytest.fixture(name="test_data_dir")
def test_data_dir_fixture():
    return os.path.join(TESTS_DIR, "test_data")


@pytest.fixture(name="config_file")
def config_file_fixture(test_data_dir: str):
    return os.path.join(test_data_dir, "config.yaml")


@pytest.fixture(name="eval_config_file")
def eval_config_file_fixture() -> str:
    return os.path.join(EXAMPLES_DIR, "evaluation_and_profiling/simple_calculator_eval/configs/config-sizing-calc.yml")


@pytest.fixture(name="simple_config_file")
def simple_config_file_fixture() -> str:
    return os.path.join(EXAMPLES_DIR, "getting_started/simple_calculator/configs/config.yml")


@pytest.fixture(name="echo_config_file")
def echo_config_file_fixture(test_data_dir: str) -> str:
    return os.path.join(test_data_dir, "echo.yaml")


@pytest.fixture(name="mock_aiohttp_session")
def mock_aiohttp_session_fixture():
    with mock.patch("aiohttp.ClientSession") as mock_aiohttp_session:
        mock_aiohttp_session.return_value = mock_aiohttp_session
        mock_aiohttp_session.__aenter__.return_value = mock_aiohttp_session

        mock_get = mock.AsyncMock()
        mock_get.return_value = mock_get
        mock_get.__aenter__.return_value = mock_get
        mock_get.text.return_value = '<td data-testid="vuln-CWEs-link-0">test_output</td>'
        mock_get.json.return_value = {"test": "output"}
        mock_aiohttp_session.request.return_value = mock_get

        yield mock_aiohttp_session
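
# Usage sketch (illustrative only; the URL and assertion are hypothetical): code under test
# that opens `aiohttp.ClientSession()` and awaits `resp.json()` / `resp.text()` receives the
# canned payloads configured above, e.g.:
#
#     import aiohttp
#
#     async def test_fetch_json(mock_aiohttp_session):
#         async with aiohttp.ClientSession() as session:
#             async with session.request("GET", "http://example.test") as resp:
#                 assert await resp.json() == {"test": "output"}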


@pytest.fixture(name="set_test_api_keys")
def set_test_api_keys_fixture(restore_environ):
    for key in ("NGC_API_KEY", "NVIDIA_API_KEY", "OPENAI_API_KEY"):
        os.environ[key] = "test_key"


@pytest.fixture(name="rapids_repo_names")
def rapids_repo_names_fixture() -> list[str]:
    return ["cugraph", "cuvs", "rmm", "raft", "cuspatial", "cuxfilter", "cucim"]


@pytest.fixture(name="rapids_repo_urls")
def rapids_repo_urls_fixture(rapids_repo_names: list[str]) -> dict[str, str]:
    return {repo: f"https://github.com/rapidsai/{repo}.git" for repo in rapids_repo_names}


@pytest.fixture(name="workflow_config")
def workflow_config_fixture():
    from _utils.configs import WorkflowTestConfig
    return WorkflowTestConfig(llm_name='test_llm', functions=['test_function'], prompt='Are you a unittest?')


@pytest.fixture(name="tools_config")
def tools_config_fixture() -> dict[str, typing.Any]:
    return {
        "test_function": {
            "_type": "test_function"
        },
        "test_tool_2": {
            "_type": "test_function"
        },
        "test_tool_3": {
            "_type": "test_function"
        },
    }


@pytest.fixture(name="llms_config")
def llms_config_fixture() -> dict[str, typing.Any]:
    return {"test_llm": {"_type": "test_llm"}, "test_llm_2": {"_type": "test_llm"}, "test_llm_3": {"_type": "test_llm"}}


class StreamingOutputModel(BaseModel):
    result: str


class SingleOutputModel(BaseModel):
    summary: str


@pytest.fixture(name="test_workflow_fn")
def test_workflow_fn_fixture():

    async def workflow_fn(_param: BaseModel) -> SingleOutputModel:
        return SingleOutputModel(summary="This is a coroutine function")

    return workflow_fn


@pytest.fixture(name="test_streaming_fn")
def test_streaming_fn_fixture():

    async def streaming_fn(_param: BaseModel) -> typing.Annotated[AsyncGenerator[StreamingOutputModel], ...]:
        yield StreamingOutputModel(result="this is an async generator")

    return streaming_fn


@pytest.fixture(name="register_test_workflow")
def register_test_workflow_fixture(test_workflow_fn) -> Callable[[], Callable]:

    def register_test_workflow():
        from _utils.configs import WorkflowTestConfig
        from nat.builder.builder import Builder
        from nat.cli.register_workflow import register_function

        @register_function(config_type=WorkflowTestConfig)
        async def build_fn(_: WorkflowTestConfig, __: Builder):
            yield test_workflow_fn

        return build_fn

    return register_test_workflow
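
# Usage sketch (hypothetical test, for illustration): the factory is invoked inside the test so
# registration happens within that test's context, e.g.:
#
#     def test_with_registered_workflow(register_test_workflow, workflow_config):
#         register_test_workflow()  # registers the build function for WorkflowTestConfig
#         ...  # build and run a workflow that references `workflow_config`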


@pytest.fixture(name="reactive_stream")
def reactive_stream_fixture():
    """
    A fixture that sets up a fresh usage_stats queue in the context var
    for each test, then resets it afterward.
    """
    from nat.builder.context import ContextState
    from nat.utils.reactive.subject import Subject

    token = None
    original_queue = ContextState.get().event_stream.get()

    try:
        new_queue = Subject()
        token = ContextState.get().event_stream.set(new_queue)
        yield new_queue
    finally:
        if token is not None:
            # Reset to the original queue after the test
            ContextState.get().event_stream.reset(token)
            ContextState.get().event_stream.set(original_queue)
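
# Usage sketch (illustrative, using only the APIs referenced above): a test can confirm that
# the fixture installed its Subject as the active event stream:
#
#     def test_event_stream_is_swapped(reactive_stream):
#         from nat.builder.context import ContextState
#         assert ContextState.get().event_stream.get() is reactive_stream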


@pytest.fixture(name="global_settings", scope="function", autouse=False)
def function_settings_fixture():
    """
    Resets and returns the global settings for testing.

    This gets automatically used at the function level to ensure no state is leaked between functions.
    """

    from nat.settings.global_settings import GlobalSettings

    with GlobalSettings.push() as settings:
        yield settings
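
# Usage sketch (the attribute below is hypothetical): tests may mutate the yielded settings
# freely; changes are discarded when the GlobalSettings.push() context exits:
#
#     def test_mutates_settings(global_settings):
#         global_settings.some_option = "value"  # hypothetical attribute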


@pytest.fixture(name="pypi_registry_channel")
def pypi_registry_channel_fixture():
    """
    Returns a pypi registry channel configuration.
    """
    return {
        "channels": {
            "pypi_channel": {
                "_type": "pypi",
                "endpoint": "http://localhost:1234",
                "publish_route": "",
                "pull_route": "",
                "search_route": "simple",
                "token": "test-token"
            }
        }
    }


@pytest.fixture(name="rest_registry_channel")
def rest_registry_channel_fixture():
    """
    Returns a rest registry channel configuration.
    """
    return {
        "channels": {
            "rest_channel": {
                "_type": "rest",
                "endpoint": "http://localhost:1234",
                "publish_route": "publish",
                "pull_route": "pull",
                "search_route": "search",
                "remove_route": "remove",
                "token": "test-token"
            }
        }
    }


@pytest.fixture(name="local_registry_channel")
def local_registry_channel_fixture():
    """
    Returns a local registry channel configuration.
    """
    return {"channels": {"local_channel": {"_type": "local"}}}


@pytest.fixture(scope="session")
def httpserver_listen_address():

    return "127.0.0.1", 0


@pytest_asyncio.fixture(scope="module")
async def mock_llm():

    class MockLLM(BaseChatModel):

        async def _agenerate(self,
                             messages: list[BaseMessage],
                             stop: list[str] | None = None,
                             run_manager: AsyncCallbackManagerForLLMRun | None = None,
                             **kwargs: typing.Any) -> ChatResult:
            # mock behavior to test agent features
            if len(messages) == 1:
                if 'mock tool call' in messages[0].content:
                    message = AIMessage(content='mock tool call',
                                        response_metadata={"mock_llm_response": True},
                                        tool_calls=[{
                                            "name": "Tool A",
                                            "args": {
                                                "query": "mock query"
                                            },
                                            "id": "Tool A",
                                            "type": "tool_call",
                                        }])
                    generation = ChatGeneration(message=message)
                    return ChatResult(generations=[generation], llm_output={'mock_llm_response': True})
            if len(messages) == 4:
                if 'fix the input on retry' in messages[2].content:
                    response = 'Thought: not many\nAction: Tool A\nAction Input: give me final answer!\nObservation:'
                    message = AIMessage(content=response, response_metadata={"mock_llm_response": True})
                    generation = ChatGeneration(message=message)
                    return ChatResult(generations=[generation], llm_output={'mock_llm_response': True})
                if 'give me final answer' in messages[3].content:
                    response = 'Final Answer: hello, world!'
                    message = AIMessage(content=response, response_metadata={"mock_llm_response": True})
                    generation = ChatGeneration(message=message)
                    return ChatResult(generations=[generation], llm_output={'mock_llm_response': True})
            message = AIMessage(content=messages[-1].content, response_metadata={"mock_llm_response": True})
            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation], llm_output={'mock_llm_response': True})

        def _generate(self,
                      messages: list[BaseMessage],
                      stop: list[str] | None = None,
                      run_manager: CallbackManagerForLLMRun | None = None,
                      **kwargs: typing.Any) -> ChatResult:
            message = AIMessage(content=messages[-1].content, response_metadata={"mock_llm_response": True})
            generation = ChatGeneration(message=message)
            return ChatResult(generations=[generation], llm_output={'mock_llm_response': True})

        def bind_tools(
                self,
                tools: Sequence[dict[str, typing.Any] | type | Callable | BaseTool],  # noqa: UP006
                **kwargs: typing.Any) -> BaseChatModel:
            return self

        @property
        def _llm_type(self) -> str:
            return 'mock-llm'

    return MockLLM()
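
# Usage sketch (illustrative): for ordinary inputs `mock_llm` echoes the last message, so a
# test can assert a simple round trip:
#
#     async def test_mock_llm_echo(mock_llm):
#         from langchain_core.messages import HumanMessage
#         result = await mock_llm.ainvoke([HumanMessage(content="hello")])
#         assert result.content == "hello"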


@pytest.fixture(scope="module")
def mock_tool():

    def _create_mock_tool(tool_name: str):

        class MockTool(BaseTool):
            name: str = tool_name
            description: str = 'test tool:' + tool_name

            async def _arun(self,
                            query: str | dict = 'test',
                            run_manager: AsyncCallbackManagerForToolRun | None = None,
                            **kwargs):  # noqa: E501
                return query

            def _run(self,
                     query: str | dict = 'test',
                     run_manager: CallbackManagerForToolRun | None = None,
                     **kwargs):  # noqa: E501
                return query

        return MockTool()

    return _create_mock_tool
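
# Usage sketch (illustrative): the factory builds a tool that simply echoes its input:
#
#     async def test_mock_tool_echo(mock_tool):
#         tool = mock_tool("Tool A")
#         assert await tool.ainvoke("ping") == "ping"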


@pytest.fixture(name="rag_user_inputs")
def rag_user_inputs_fixture() -> list[str]:
    """Fixture providing multiple user inputs."""
    return ["What is ML?", "What is NLP?"]


@pytest.fixture(name="rag_generated_outputs")
def rag_generated_outputs_fixture() -> list[str]:
    """Fixture providing workflow generated outputs corresponding to user inputs."""
    return ["ML is the abbreviation for Machine Learning", "NLP stands for Natural Language Processing"]


@pytest.fixture(name="rag_intermediate_steps")
def rag_intermediate_steps_fixture(rag_user_inputs, rag_generated_outputs) -> list[list["IntermediateStep"]]:
    """
    Fixture to generate separate lists of IntermediateStep objects for each user input.

    Each list includes:
    1. LLM_START, LLM_NEW_TOKENs, LLM_END
    2. TOOL_START, and TOOL_END.

    Returns:
        (list for user_input_1, list for user_input_2)
    """
    from nat.builder.framework_enum import LLMFrameworkEnum
    from nat.data_models.intermediate_step import IntermediateStep
    from nat.data_models.intermediate_step import IntermediateStepPayload
    from nat.data_models.intermediate_step import IntermediateStepType
    from nat.data_models.intermediate_step import StreamEventData
    from nat.data_models.invocation_node import InvocationNode

    framework = LLMFrameworkEnum.LANGCHAIN
    token_cnt = 10
    llm_name = "mock_llm"
    tool_name = "mock_tool"

    def create_step(event_type,
                    name=llm_name,
                    input_data=None,
                    output_data=None,
                    chunk=None,
                    step_uuid: str | None = None):
        """Helper to create an `IntermediateStep`."""
        if step_uuid is None:
            step_uuid = str(uuid.uuid4())
        return IntermediateStep(parent_id="root",
                                function_ancestry=InvocationNode(function_name=name,
                                                                 function_id=f"test-{name}-{step_uuid}"),
                                payload=IntermediateStepPayload(UUID=step_uuid,
                                                                event_type=event_type,
                                                                framework=framework,
                                                                name=name,
                                                                data=StreamEventData(input=input_data,
                                                                                     output=output_data,
                                                                                     chunk=chunk)))

    step_lists = []  # Store separate lists

    for user_input, generated_output in zip(rag_user_inputs, rag_generated_outputs):
        tool_input = f"Get me the documents for {user_input}"
        tool_output = f"Here is information I have on {user_input}"

        llm_start_step = create_step(IntermediateStepType.LLM_START, input_data=user_input)

        steps = [
            llm_start_step,
            *[
                create_step(IntermediateStepType.LLM_NEW_TOKEN, chunk=f"Token {i} for {user_input}")
                for i in range(token_cnt)
            ],
            create_step(IntermediateStepType.LLM_END,
                        input_data=user_input,
                        output_data=generated_output,
                        step_uuid=llm_start_step.UUID)
        ]

        tool_start_step = create_step(IntermediateStepType.TOOL_START, name=tool_name, input_data=tool_input)
        steps.append(tool_start_step)

        steps.append(
            create_step(IntermediateStepType.TOOL_END,
                        name=tool_name,
                        input_data=tool_input,
                        output_data=tool_output,
                        step_uuid=tool_start_step.UUID))

        step_lists.append(steps)  # Append separate list for each user input

    return step_lists
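
# Shape note (follows from the constants above): each inner list holds 1 LLM_START,
# `token_cnt` LLM_NEW_TOKEN steps, 1 LLM_END, 1 TOOL_START and 1 TOOL_END, i.e. 14 entries
# per user input when token_cnt == 10:
#
#     def test_step_shape(rag_intermediate_steps):
#         assert all(len(steps) == 14 for steps in rag_intermediate_steps)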


@pytest.fixture(name="rag_intermediate_property_adaptor")
def rag_intermediate_property_adaptor_fixture(rag_intermediate_steps) -> list[list["IntermediatePropertyAdaptor"]]:
    """
    Fixture to transform the rag_intermediate_steps fixture data into IntermediatePropertyAdaptor objects.
    """
    from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor

    return [[IntermediatePropertyAdaptor.from_intermediate_step(step) for step in steps]
            for steps in rag_intermediate_steps]


@pytest.fixture(name="dask_cluster", scope="session")
def dask_cluster_fixture(fail_missing: bool) -> "LocalCluster":
    """
    Fixture to provide a Dask LocalCluster for tests.
    """
    try:
        from dask.distributed import LocalCluster
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("Dask is not installed, skipping Dask cluster fixture.")

    cluster = LocalCluster(asynchronous=False, n_workers=1, threads_per_worker=1)
    yield cluster
    cluster.close()


@pytest.fixture(name="dask_scheduler_address", scope="session")
def dask_scheduler_address_fixture(dask_cluster: "LocalCluster") -> str:
    """
    Fixture to provide the Dask scheduler address for tests.
    """
    return dask_cluster.scheduler.address
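
# Usage sketch (illustrative): a test can connect a Dask client to the provided address:
#
#     def test_scheduler_reachable(dask_scheduler_address):
#         from dask.distributed import Client
#         with Client(dask_scheduler_address) as client:
#             assert client.status == "running"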


@pytest.fixture(name="db_engine")
def db_engine_fixture(fail_missing: bool, tmp_path: Path) -> "AsyncEngine":
    """
    Fixture to provide a SQLAlchemy AsyncEngine connected to a temporary SQLite database for tests.
    """
    try:
        from sqlalchemy.ext.asyncio import create_async_engine
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("SQLAlchemy is not installed, skipping database engine fixture.")

    db_path = tmp_path / "test_db.sqlite"
    db_url = f"sqlite+aiosqlite:///{db_path}"
    db_engine = create_async_engine(db_url, echo=False, future=True)
    return db_engine


@pytest_asyncio.fixture(name="setup_db")
async def setup_db_fixture(db_engine: "AsyncEngine"):
    """
    Fixture to create database tables before tests and drop them afterward.
    """
    from nat.front_ends.fastapi.job_store import Base

    async with db_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all, checkfirst=True)
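
# Usage sketch (illustrative; the session usage is hypothetical): a test depending on
# `setup_db` can open an async session against the same engine and work with the freshly
# created job store tables:
#
#     async def test_job_store_tables(setup_db, db_engine):
#         from sqlalchemy.ext.asyncio import AsyncSession
#         async with AsyncSession(db_engine) as session:
#             ...  # query or insert job store rows here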


@pytest.fixture(name="db_url")
def db_url_fixture(db_engine: "AsyncEngine") -> str:
    """
    Fixture to provide the database URL for the tests.
    """
    return str(db_engine.url)


@pytest.fixture(name="set_nat_config_file_env_var")
def fixture_set_nat_config_file_env_var(restore_environ, echo_config_file: str) -> str:
    """
    Fixture to set the NAT_CONFIG_FILE environment variable for tests.
    This ensures that tests have a consistent configuration file path.
    """
    os.environ["NAT_CONFIG_FILE"] = echo_config_file
    return echo_config_file


@pytest.fixture(name="set_nat_dask_scheduler_env_var")
def fixture_set_nat_dask_scheduler_env_var(restore_environ, dask_scheduler_address: str) -> str:
    """
    Fixture to set the NAT_DASK_SCHEDULER_ADDRESS environment variable for tests.
    This ensures that tests have a consistent Dask scheduler address.
    """
    os.environ["NAT_DASK_SCHEDULER_ADDRESS"] = dask_scheduler_address
    return dask_scheduler_address


@pytest.fixture(name="set_nat_job_store_db_url_env_var")
def fixture_set_nat_job_store_db_url_env_var(restore_environ, db_url: str) -> str:
    """
    Fixture to set the NAT_JOB_STORE_DB_URL environment variable for tests.
    This ensures that tests have a consistent job store database URL.
    """
    os.environ["NAT_JOB_STORE_DB_URL"] = db_url
    return db_url


@pytest.fixture(name="register_empty_function", scope="session", autouse=True)
def register_empty_function_fixture():
    from nat.builder.builder import Builder
    from nat.cli.register_workflow import register_function
    from nat.data_models.function import EmptyFunctionConfig

    @register_function(config_type=EmptyFunctionConfig)
    async def empty_function(config: EmptyFunctionConfig, builder: Builder):

        async def inner(*_, **__):
            return

        yield inner


@pytest.fixture(name="reset_global_type_converter")
def reset_global_type_converter_fixture():
    """
    Restore the GlobalTypeConverter to its previous state after a test that manipulates it in some way.
    """
    from nat.utils.type_converter import GlobalTypeConverter

    orig_converters = copy.deepcopy(GlobalTypeConverter.get()._converters)

    yield
    GlobalTypeConverter.get()._converters = orig_converters
